from types import MethodType
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaRMSNorm,
    apply_rotary_pos_emb,
)

SUPPORT_XFORMERS = False
SUPPORT_FLASH2 = False

try:
    import xformers.ops as xops

    SUPPORT_XFORMERS = True
except ImportError:
    pass

try:
    from flash_attn import flash_attn_func

    SUPPORT_FLASH2 = True
except ImportError:
    pass

SUPPORT_FLASH = SUPPORT_XFORMERS or SUPPORT_FLASH2


def llama_flash_attention(
    self: LlamaAttention,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

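    # Project hidden states to Q/K/V and split heads:
    # [bsz, q_len, hidden_size] -> [bsz, num_heads, q_len, head_dim].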
    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
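    # The cos/sin tables must cover the full sequence (cache + new tokens) so
    # that position_ids index correctly into the rotary embedding.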
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # q, k, v are [B, H, S, D]; flash-attn and xformers both expect [B, S, H, D]
    # and return the output in that same layout. Note that `attention_mask` is
    # ignored here: both code paths apply a purely causal mask.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)
    if SUPPORT_FLASH2:
        attn_output = flash_attn_func(
            query_states, key_states, value_states, causal=True
        )
    elif SUPPORT_XFORMERS:
        attn_output = xops.memory_efficient_attention(
            query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
        )
    else:
        raise RuntimeError(
            "llama_flash_attention requires flash-attn or xformers to be installed"
        )

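    # Merge heads back: [bsz, q_len, num_heads, head_dim] -> [bsz, q_len, hidden_size].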
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    # Fused kernels never materialize the full attention matrix, so per-head
    # attention weights cannot be returned even when output_attentions is set.
    attn_weights = None

    return attn_output, attn_weights, past_key_value


def llama_fast_rms_norm(self: LlamaRMSNorm, hidden_states: torch.FloatTensor):
    # Requires apex (https://github.com/NVIDIA/apex) built with fused kernels.
    from apex.normalization.fused_layer_norm import (
        manual_rms_norm,
        mixed_dtype_fused_rms_norm_affine,
    )

    # apex expects an iterable normalized_shape, not a bare int.
    normalized_shape = self.weight.shape
    if not hidden_states.is_cuda:
        return manual_rms_norm(
            hidden_states, normalized_shape, self.weight, self.variance_epsilon
        )
    return mixed_dtype_fused_rms_norm_affine(
        hidden_states, self.weight, normalized_shape, self.variance_epsilon
    )
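
# Quick numerical check (a sketch; assumes apex is installed and a CUDA GPU is
# available). The fused kernel should closely match the stock LlamaRMSNorm:
#
#     norm = LlamaRMSNorm(4096).cuda().half()
#     x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.float16)
#     print(torch.allclose(norm(x), llama_fast_rms_norm(norm, x), atol=1e-3))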


def replace_xformers(model: nn.Module):
    if not SUPPORT_FLASH:
        raise RuntimeError(
            "replace_xformers requires flash-attn or xformers to be installed"
        )
    for module in model.modules():
        if isinstance(module, LlamaAttention):
            module.forward = MethodType(llama_flash_attention, module)
        # Uncomment to also patch RMSNorm with the fused apex kernel:
        # if isinstance(module, LlamaRMSNorm):
        #     module.forward = MethodType(llama_fast_rms_norm, module)
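

# Usage (a minimal sketch; the checkpoint path below is a placeholder):
#
#     from transformers import LlamaForCausalLM
#
#     model = LlamaForCausalLM.from_pretrained("path/to/llama-7b").half().cuda()
#     replace_xformers(model)  # all LlamaAttention modules now use fused kernels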
